#############################################################################################
# Methods for efficiently computing model objects:
# Laplace transforms, state-space decompositions, etc.
#############################################################################################


################################################################
# y(s) := -s^2 A'(s): object and methods
struct lapl_y
    # inputs: the caller's arrays are stored by reference (not copied), so
    # in-place changes made elsewhere are visible to update_lapl_y!
    M::Matrix{Float64}          # square J×J matrix
    s::Vector{Float64}          # evaluation point, read as s[1]
    b_vec::Vector{Float64}
    a0_vec::Vector{Float64}

    # work buffers, overwritten in place by update_lapl_y!
    sIM::Matrix{Float64}        # sI + M
    sIM_inv::Matrix{Float64}    # (sI + M)^{-1}
    A::Vector{Float64}          # A(s)
    y::Vector{Float64}          # y(s) = -s^2 A'(s)
    yp::Vector{Float64}         # y'(s), derivative of y with respect to s
end

# Wrap the given model arrays and allocate zero-filled buffers sized from M.
# The buffers hold meaningful values only after update_lapl_y! has run.
function lapl_y(M::Matrix{Float64}, s::Vector{Float64},
        b_vec::Vector{Float64}, a0_vec::Vector{Float64}
    )
    n = size(M, 1)
    sIM_buf, sIM_inv_buf = zeros(n, n), zeros(n, n)
    A_buf, y_buf, yp_buf = zeros(n), zeros(n), zeros(n)
    return lapl_y(M, s, b_vec, a0_vec, sIM_buf, sIM_inv_buf, A_buf, y_buf, yp_buf)
end
function update_lapl_y!(Ly::lapl_y)::Nothing
    # Recompute the cached y(s) quantities in place at s = Ly.s[1].
    # Derivatives with respect to M are computed separately (deriv_lapl_y_wrtM!).
    s = Ly.s[1]
    b, a0 = Ly.b_vec, Ly.a0_vec

    # resolvent (sI + M)^{-1}; stored because the derivative routines reuse it
    Ly.sIM .= s * I + Ly.M
    Ly.sIM_inv .= inv(Ly.sIM)
    R = Ly.sIM_inv

    # A(s) = (sI + M)^{-1} (b/s + a0)
    Ly.A .= R * (b ./ s .+ a0)
    # y(s) = -s^2 A'(s) = (sI + M)^{-1} (b + s^2 A(s))
    Ly.y .= R * (b .+ s^2 .* Ly.A)
    # y'(s) = 2 (sI + M)^{-1} (s A(s) - y(s))
    Ly.yp .= 2 .* (R * (s .* Ly.A .- Ly.y))
    return nothing
end

##### derivs (inplace)
function deriv_lapl_y_wrtM!(dy::AbstractVector, Ly::lapl_y, dM::AbstractMatrix)::Nothing
    # Directional derivative of y(s) with respect to M along dM, written into dy.
    # Reads the cached sIM_inv, A and y, so those should be current
    # (presumably via update_lapl_y!).
    s = Ly.s[1]
    R = Ly.sIM_inv
    # dA = -(sI + M)^{-1} dM A(s)
    dA = -R * (dM * Ly.A)
    # dy = (sI + M)^{-1} (s^2 dA - dM y(s))
    dy[:] = R * (s^2 .* dA - dM * Ly.y)
    return nothing
end

function deriv_lapl_y_wrtM!(dy::AbstractVector, Ly::lapl_y,
        j_idx::Int, k_idx::Int)::Nothing
    # Partial derivative of y(s) with respect to the single entry m_{j,k}.
    # This is the directional derivative along dM = e_j e_k', computed from a
    # column of the cached inverse rather than by forming dM explicitly.
    s = Ly.s[1]
    R = Ly.sIM_inv
    col_j = R[:, j_idx]
    dA = -col_j .* Ly.A[k_idx]
    dy[:] = s^2 .* R * dA - col_j .* Ly.y[k_idx]
    return nothing
end





################################################################
# Y(S_vec) object and methods
struct lapl_Y
    # inputs: the caller's arrays are stored by reference (not copied)
    M::Matrix{Float64}                # square J×J matrix
    S_vec::Vector{Float64}            # K evaluation points
    b_vec::Vector{Float64}
    a0_vec::Vector{Float64}

    # work buffers, overwritten in place by update_lapl_Y!; the trailing index
    # is the position of the evaluation point in S_vec
    sIM_arr::Array{Float64, 3}        # J×J×K, slices sI + M
    sIM_inv_arr::Array{Float64, 3}    # J×J×K, slices (sI + M)^{-1}
    A_arr::Matrix{Float64}            # J×K, columns A(s)
    Y::Matrix{Float64}                # J×K, columns y(s) = -s^2 A'(s)
    Yp::Matrix{Float64}               # J×K, columns y'(s)
end

# Wrap the given model arrays and allocate zero-filled buffers.
# The trailing dimension of every buffer is length(S_vec), not J: update_lapl_Y!
# and the derivative routines index slice/column k for each position k in S_vec.
# Sizing it by J would throw a BoundsError whenever length(S_vec) > J.
function lapl_Y(M::Matrix{Float64}, S_vec::Vector{Float64},
        b_vec::Vector{Float64}, a0_vec::Vector{Float64}
    )
    J = size(M, 1)
    K = length(S_vec)
    return lapl_Y(M, S_vec, b_vec, a0_vec,
        zeros(J, J, K), zeros(J, J, K), zeros(J, K), zeros(J, K), zeros(J, K)
    )
end
function update_lapl_Y!(LY::lapl_Y)::Nothing
    # Recompute the cached quantities for every point of LY.S_vec in place,
    # writing the k-th point's results into slice/column k. Points with s == 0
    # are skipped, and their slices keep whatever they held before.
    M, b, a0 = LY.M, LY.b_vec, LY.a0_vec
    for (k, s) in enumerate(LY.S_vec)
        s == 0 && continue

        sIM = view(LY.sIM_arr, :, :, k)
        R = view(LY.sIM_inv_arr, :, :, k)
        A = view(LY.A_arr, :, k)
        y = view(LY.Y, :, k)
        yp = view(LY.Yp, :, k)

        # resolvent (sI + M)^{-1}
        sIM .= s * I + M
        R .= inv(sIM)
        # A(s) = (sI + M)^{-1} (b/s + a0)
        A .= R * (b ./ s .+ a0)
        # y(s) = -s^2 A'(s) = (sI + M)^{-1} (b + s^2 A(s))
        y .= R * (b .+ A .* s^2)
        # y'(s) = 2 (sI + M)^{-1} (s A(s) - y(s))
        yp .= 2 .* (R * (s .* A .- y))
    end
    return nothing
end

##### derivs (inplace)
function deriv_lapl_Y_wrtM!(dY::AbstractMatrix, LY::lapl_Y, dM::AbstractMatrix)::Nothing
    # Directional derivative of Y with respect to M along dM: column k of dY
    # receives the derivative of y(s) at s = S_vec[k]. Columns for s == 0 are
    # not written. Reads the cached sIM_inv_arr, A_arr and Y.
    for (k, s) in enumerate(LY.S_vec)
        s == 0 && continue
        R = view(LY.sIM_inv_arr, :, :, k)
        A = view(LY.A_arr, :, k)
        y = view(LY.Y, :, k)
        # dA = -(sI + M)^{-1} dM A(s);  dy = (sI + M)^{-1} (s^2 dA - dM y(s))
        dA = -R * (dM * A)
        dY[:, k] = R * (s^2 .* dA - dM * y)
    end
    return nothing
end

function deriv_lapl_Y_wrtM!(dY::AbstractMatrix, LY::lapl_Y,
        j_idx::Int, k_idx::Int)::Nothing
    # Partial derivative of Y with respect to the single entry m_{j,k}
    # (direction dM = e_j e_k'). Column c of dY receives the result for
    # s = S_vec[c]; columns for s == 0 are not written.
    for (c, s) in enumerate(LY.S_vec)
        s == 0 && continue
        R = view(LY.sIM_inv_arr, :, :, c)
        Rj = R[:, j_idx]
        dA = -Rj .* LY.A_arr[k_idx, c]
        dY[:, c] = s^2 .* R * dA - Rj .* LY.Y[k_idx, c]
    end
    return nothing
end


################################################################
# X(s) objects and methods

struct lapl_X
    # inputs: the caller's arrays are stored by reference (not copied)
    M::Matrix{Float64}              # square J×J matrix
    s::Vector{Float64}              # evaluation point, read as s[1]
    b_vec::Vector{Float64}
    a0_vec::Vector{Float64}

    # work buffers, overwritten in place by update_lapl_X!
    sIM::Matrix{Float64}            # sI + M
    sIM_inv::Matrix{Float64}        # (sI + M)^{-1}
    sIM_half::Matrix{Float64}       # (s/2) I + M, Lyapunov coefficient matrix
    sIM_half_R::Matrix{Float64}     # Schur form (T factor) of sIM_half
    sIM_half_Q::Matrix{Float64}     # Schur vectors (Z factor) of sIM_half
    A::Vector{Float64}              # A(s)
    Ap::Vector{Float64}             # A'(s)
    X::Matrix{Float64}              # X(s), Lyapunov solution
    Xp::Matrix{Float64}             # X'(s)
end

# Wrap the given model arrays and allocate zero-filled buffers sized from M.
# The buffers hold meaningful values only after update_lapl_X! has run.
function lapl_X(M::Matrix{Float64}, s::Vector{Float64},
        b_vec::Vector{Float64}, a0_vec::Vector{Float64}
    )
    n = size(M, 1)
    sq() = zeros(n, n)
    col() = zeros(n)
    return lapl_X(M, s, b_vec, a0_vec,
        sq(), sq(), sq(), sq(), sq(),   # sIM, sIM_inv, sIM_half, sIM_half_R, sIM_half_Q
        col(), col(),                   # A, Ap
        sq(), sq(),                     # X, Xp
    )
end
function update_lapl_X!(LX::lapl_X)::Nothing
    # Recompute the cached A(s), A'(s), X(s) and X'(s) in place at s = LX.s[1].
    s = LX.s[1]
    M, b, a0 = LX.M, LX.b_vec, LX.a0_vec

    # --- A(s) and A'(s) ---
    LX.sIM .= s * I + M
    LX.sIM_inv .= inv(LX.sIM)
    R = LX.sIM_inv
    A = LX.A
    # A(s) = (sI + M)^{-1} (b/s + a0)
    A .= R * (b ./ s .+ a0)
    # A'(s) = -(sI + M)^{-1} (b/s^2 + A(s))
    LX.Ap .= -R * (b ./ s^2 .+ A)

    # --- X(s): Lyapunov equation with coefficient matrix (s/2) I + M ---
    LX.sIM_half .= (0.5 * s) * I + M
    # cache the Schur factors (T, Z) so X'(s) and the M-derivatives can reuse them
    F = schur(LX.sIM_half)
    LX.sIM_half_R .= F.T
    LX.sIM_half_Q .= F.Z
    bA = b * A'
    Q = bA + bA' + a0 * a0'
    # lyap(R, Q, C) is assumed to follow LinearAlgebra.lyap's convention
    # (solve K X + X K' + C = 0), hence the negated right-hand side
    LX.X .= lyap(LX.sIM_half_R, LX.sIM_half_Q, -Q)

    # --- X'(s): same operator, right-hand side differentiated with respect to s ---
    bAp = b * LX.Ap'
    Qp = -LX.X + bAp + bAp'
    LX.Xp .= lyap(LX.sIM_half_R, LX.sIM_half_Q, -Qp)
    return nothing
end


##### derivs (inplace)
function deriv_lapl_X_wrtM!(dX::AbstractMatrix, LX::lapl_X, dM::AbstractMatrix)::Nothing
    # Directional derivative of X(s) with respect to M along dM, written into dX.
    # Reuses the cached (sI + M)^{-1}, A(s), X(s) and the Schur factors of
    # (s/2) I + M computed by the update step.
    A, X = LX.A, LX.X
    # dA = -(sI + M)^{-1} dM A(s)
    dA = -LX.sIM_inv * (dM * A)
    # differentiated Lyapunov right-hand side:  b dA' + dA b' - dM X - (dM X)'
    bdA = LX.b_vec * dA'
    dMX = dM * X
    rhs = bdA + bdA' - dMX - dMX'
    # same sign convention as the update step: solver receives the negated RHS
    dX .= lyap(LX.sIM_half_R, LX.sIM_half_Q, -rhs)
    return nothing
end

function deriv_lapl_X_wrtM!(dX::AbstractMatrix, LX::lapl_X,
        j_idx::Int, k_idx::Int)::Nothing
    # Partial derivative of X(s) with respect to the single entry m_{j,k},
    # i.e. the directional derivative along dM = e_j e_k', without forming dM.
    X = LX.X
    dA = -LX.sIM_inv[:, j_idx] .* LX.A[k_idx]
    # X dM' has a single nonzero column: column j equals X[:, k]. It stands in
    # for (dM X)'. The two agree when X is symmetric, which X is expected to be,
    # since its Lyapunov right-hand side is symmetric.
    XdMt = zero(LX.sIM)
    XdMt[:, j_idx] = X[:, k_idx]
    bdA = LX.b_vec * dA'
    rhs = bdA + bdA' - XdMt - XdMt'
    dX .= lyap(LX.sIM_half_R, LX.sIM_half_Q, -rhs)
    return nothing
end


################################################################
# state-space decomposition

struct StateDynamics
    # dynamics matrices: Upsilon is the N×N input. Gamma (J×J) and Omega
    # ((N-J)×J) are derived from its eigendecomposition by update_state_dynamics!
    Upsilon::Matrix{Float64}
    Gamma::Matrix{Float64}
    Omega::Matrix{Float64}

    # eigendecomposition of Upsilon, eigenpairs ordered by decreasing real part
    Lam_diag::Vector{ComplexF64}     # all N eigenvalues
    W::Matrix{ComplexF64}            # N×N eigenvectors (as columns)
    Lam1_diag::Vector{ComplexF64}    # first J eigenvalues
    Lam2_diag::Vector{ComplexF64}    # remaining N-J eigenvalues
    W11::Matrix{ComplexF64}          # W[1:J, 1:J]
    W21::Matrix{ComplexF64}          # W[J+1:end, 1:J]

    # N = total number of variables, J = number of state variables
    N::Int
    J::Int
end

# Wrap Upsilon and allocate zero-filled buffers for a system with J state
# variables. Call update_state_dynamics! to fill them.
function StateDynamics(Upsilon::Matrix{Float64}, J::Int)
    N = size(Upsilon, 1)
    nj = N - J
    cplx(dims...) = zeros(ComplexF64, dims...)
    return StateDynamics(
        Upsilon, zeros(J, J), zeros(nj, J),
        cplx(N), cplx(N, N), cplx(J), cplx(nj), cplx(J, J), cplx(nj, J),
        N, J,
    )
end
function update_state_dynamics!(sdyn::StateDynamics)::Nothing
    # Eigendecompose Upsilon, order the eigenpairs by decreasing real part, and set
    #   Gamma = W11 Λ1 W11^{-1},   Omega = W21 W11^{-1}
    # where Λ1 holds the first J eigenvalues and W11 / W21 are the top / bottom
    # blocks of the first J eigenvectors. All results are written into sdyn.
    J = sdyn.J
    F = eigen(sdyn.Upsilon)
    order = sortperm(real.(F.values); rev=true)

    sdyn.Lam_diag .= F.values[order]
    sdyn.W .= F.vectors[:, order]
    sdyn.Lam1_diag .= sdyn.Lam_diag[1:J]
    sdyn.Lam2_diag .= sdyn.Lam_diag[J+1:end]
    sdyn.W11 .= sdyn.W[1:J, 1:J]
    sdyn.W21 .= sdyn.W[J+1:end, 1:J]

    # only the real parts are kept. NOTE(review): this assumes the imaginary parts
    # cancel, presumably because complex eigenpairs come in conjugate pairs that
    # are not split at index J — confirm for the models in use.
    W11 = sdyn.W11
    sdyn.Gamma .= real.(W11 * Diagonal(sdyn.Lam1_diag) / W11)
    sdyn.Omega .= real.(sdyn.W21 / W11)
    return nothing
end

##### derivs (inplace)
function deriv_state_dynamics_wrtUpsilon!(dGamma::AbstractMatrix,
        sdyn::StateDynamics, dUps::AbstractMatrix)::Nothing
    # Directional derivative of Gamma = W11 Λ1 W11^{-1} along dUps, written into
    # dGamma. Differentiating Gamma W11 = W11 Λ1 gives
    #   dGamma = (-Gamma dW11 + dW11 Λ1 + W11 dΛ1) W11^{-1}.
    # Uses the cached eigendecomposition stored in sdyn.
    J = sdyn.J
    dLam, dW = deriv_simple_eigdecomp(sdyn.Lam_diag, sdyn.W, sdyn.Upsilon, dUps)
    dLam1 = Diagonal(dLam[1:J])
    dW11 = dW[1:J, 1:J]
    numer = -sdyn.Gamma * dW11 + dW11 * Diagonal(sdyn.Lam1_diag) + sdyn.W11 * dLam1
    dGamma .= real.(numer / sdyn.W11)
    return nothing
end

function deriv_jump_dynamics_wrtUpsilon!(dOmega::AbstractMatrix,
        sdyn::StateDynamics, dUps::AbstractMatrix)::Nothing
    # Directional derivative of Omega = W21 W11^{-1} along dUps, written into
    # dOmega. Differentiating Omega W11 = W21 gives
    #   dOmega = (dW21 - Omega dW11) W11^{-1}.
    # Only the eigenvector derivative is needed; the eigenvalue part is discarded.
    J = sdyn.J
    _, dW = deriv_simple_eigdecomp(sdyn.Lam_diag, sdyn.W, sdyn.Upsilon, dUps)
    dW11 = dW[1:J, 1:J]
    dW21 = dW[J+1:end, 1:J]
    dOmega .= real.((dW21 - sdyn.Omega * dW11) / sdyn.W11)
    return nothing
end



################################################################
# alpha/theta objects and methods
function calc_alpha_tau(tau::Float64, alpha0::Float64, alpha1::Float64)::Float64
    # Exponentially decaying profile: alpha(tau) = alpha0 * exp(-alpha1 * tau).
    decay = exp(-alpha1 * tau)
    return alpha0 * decay
end
function calc_theta_tau(tau::Float64, theta0::Float64, theta1::Float64)::Float64
    # Hump-shaped profile: theta(tau) = theta0 * theta1^2 * tau * exp(-theta1 * tau).
    decay = exp(-theta1 * tau)
    return theta0 * theta1^2 * tau * decay
end
function calc_Theta_tau(tau::Float64, Theta0::Vector{Float64},
        Theta1::Vector{Float64})::Vector{Float64}
    # Elementwise calc_theta_tau over paired parameter vectors. Entries whose
    # Theta1 is zero are set to 0.0 without calling the scalar function.
    return [
        Theta1[j] == 0 ? 0.0 : calc_theta_tau(tau, Theta0[j], Theta1[j])
        for j in 1:length(Theta0)
    ]
end

##### derivs
function deriv_alpha_tau(tau::Float64, alpha0::Float64, alpha1::Float64,
        dalpha0::Float64, dalpha1::Float64)::Float64
    # Directional derivative of alpha(tau) = alpha0 * exp(-alpha1 * tau) along
    # (dalpha0, dalpha1):  (dalpha0 - dalpha1 * alpha0 * tau) * exp(-alpha1 * tau).
    decay = exp(-alpha1 * tau)
    return (dalpha0 - dalpha1 * alpha0 * tau) * decay
end
function deriv_theta_tau(tau::Float64, theta0::Float64, theta1::Float64,
        dtheta0::Float64, dtheta1::Float64)::Float64
    # Directional derivative of theta(tau) = theta0 * theta1^2 * tau * exp(-theta1 * tau)
    # along (dtheta0, dtheta1), factored as
    #   theta1 * tau * exp(-theta1 * tau) * (theta1 dtheta0 + theta0 (2 - theta1 tau) dtheta1)
    coeff = theta1 * dtheta0 + (2 * theta0 - theta0 * theta1 * tau) * dtheta1
    decay = exp(-theta1 * tau)
    return coeff * theta1 * tau * decay
end
function deriv_Theta_tau(tau::Float64, Theta0::Vector{Float64}, Theta1::Vector{Float64},
        dTheta0::Vector{Float64}, dTheta1::Vector{Float64})::Vector{Float64}
    # Elementwise deriv_theta_tau over paired parameter/direction vectors.
    # Entries whose Theta1 is zero are left at 0.0 without calling the scalar
    # derivative (all four inputs are still read at that index).
    out = zeros(length(Theta0))
    for j in eachindex(out)
        t0, t1, dt0, dt1 = Theta0[j], Theta1[j], dTheta0[j], dTheta1[j]
        t1 == 0 && continue
        out[j] = deriv_theta_tau(tau, t0, t1, dt0, dt1)
    end
    return out
end
